#ifndef WARP_CTC_LOSS_LAYER_HPP
#define WARP_CTC_LOSS_LAYER_HPP

#include <list>
#include <vector>

#include "caffe/layers/loss_layer.hpp"

namespace caffe {
/**
 * @brief Implementation of the CTC (Connectionist Temporal Classification) algorithm
 *        to label unsegmented sequence data with recurrent neural networks
 *        using the warp-ctc implementation of Baidu
 *
 * The blank label index is read from ctc_loss_param().blank_index() when the
 * layer is constructed and kept in blank_index_.
 * NOTE(review): this comment used to state that the blank label is always at
 * index 0, which contradicts the configurable blank_index_ member; confirm
 * against the implementation which values are actually honoured.
 *
 * @see CTCLossLayer
 */
template<typename Dtype>
class WarpCTCLossLayer: public LossLayer<Dtype> {
public:
	/**
	 * @brief Creates the layer with zeroed dimensions and the configured
	 *        blank index
	 *
	 * @param param layer configuration; the value of
	 *        ctc_loss_param().blank_index() is stored in blank_index_
	 */
	explicit WarpCTCLossLayer(const LayerParameter& param) :
			LossLayer<Dtype>(param), T_(0), N_(0), C_(0),
			blank_index_(param.ctc_loss_param().blank_index()) {
	}
	virtual ~WarpCTCLossLayer();

	virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
			const vector<Blob<Dtype>*>& top);
	virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
			const vector<Blob<Dtype>*>& top);

	/// Name under which this layer type is identified
	virtual inline const char* type() const {
		return "WarpCTCLoss";
	}

	// Between 2 and 4 bottom blobs are accepted. The documented layout uses
	// three: probabilities, sequence indicators, target sequence.
	// NOTE(review): the meaning of the 2- and 4-blob layouts is not visible
	// in this header; check the implementation file.
	virtual inline int MinBottomBlobs() const {
		return 2;
	}
	virtual inline int MaxBottomBlobs() const {
		return 4;
	}
	// -1: no exact bottom count is requested (Caffe Layer convention), so
	// only the min/max bounds above apply.
	virtual inline int ExactNumBottomBlobs() const {
		return -1;
	}

	// Exactly one top blob: the loss
	virtual inline int ExactNumTopBlobs() const {
		return 1;
	}

protected:
	/**
	 * @brief Computes the loss and the error gradients for the input data
	 *        in one step (due to optimization issues)
	 *
	 * @param bottom input Blob vector (length 3)
	 *   -# @f$ (T \times N \times C) @f$
	 *      the inputs @f$ x @f$
	 *   -# @f$ (T \times N) @f$
	 *      the sequence indicators for the data
	 *      (must be 0 at @f$ t = 0 @f$ and 1 during a sequence)
	 *   -# @f$ (T \times N) @f$
	 *      the target sequence
	 *      (must start at @f$ t = 0 @f$ and filled with -1 if the sequence has ended)
	 * @param top output Blob vector (length 1)
	 *   -# @f$ (1) @f$
	 *      the computed loss
	 */
	virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
			const vector<Blob<Dtype>*>& top);
	/**
	 * @brief Unused. Gradient calculation is done in Forward_cpu
	 */
	virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
			const vector<bool>& propagate_down,
			const vector<Blob<Dtype>*>& bottom);

	/**
	 * @brief GPU variant of the forward pass; same parameters as Forward_cpu
	 */
	virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
			const vector<Blob<Dtype>*>& top);

	/**
	 * @brief GPU variant of the backward pass; same parameters as Backward_cpu
	 */
	virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
			const vector<bool>& propagate_down,
			const vector<Blob<Dtype>*>& bottom);

	/**
	 * @brief Reads the sequence indicator and label blobs and writes the
	 *        results through the three output pointers
	 *
	 * NOTE(review): the exact extraction rules are defined in the
	 * implementation file; the descriptions below follow the parameter names.
	 *
	 * @param seq_ind_blob sequence indicators (input)
	 * @param labels_blob target sequences (input)
	 * @param flat_labels output: presumably all labels concatenated
	 * @param label_lengths output: presumably the label count per sequence
	 * @param input_lengths output: presumably the input length per sequence
	 */
	void ExtractInputData(const Blob<Dtype>* seq_ind_blob,
			const Blob<Dtype>* labels_blob, vector<int>* flat_labels,
			vector<int>* label_lengths, vector<int>* input_lengths);
private:

	// Input dimensions (see the T x N x C input described at Forward_cpu);
	// all start at 0 and are presumably filled in during setup/reshape.
	int T_;
	int N_;
	int C_;

	// blank index of input sequence, set to -1 for last
	// if set to 0, the 'real' labels must start at 1
	int blank_index_;

	// Label/length buffers; presumably the cached results of
	// ExtractInputData - confirm in the implementation.
	vector<int> flat_labels_;
	vector<int> label_lengths_;
	vector<int> input_lengths_;

	// Workspace memory; presumably scratch space handed to warp-ctc - confirm.
	shared_ptr<SyncedMemory> workspace_;

};

}  // namespace caffe

#endif // WARP_CTC_LOSS_LAYER_HPP
